"""
Envelope construction for the kernel‑to‑metric simulation.

Given fractal‑dimension anchors and pivot parameters, this module builds
a radially symmetric envelope ``E0(x, y)`` on a square lattice.  The
envelope is derived from the pivot function ``g(D) = a·D + b`` and the
context levels ``n`` from ``D_values.csv``.  We treat the context levels
as radial shells and interpolate the pivot weights onto a dense grid.

The steps are:

1. Load ``n`` and ``D`` arrays from ``data/D_values.csv``.
2. Compute pivot weights ``g(D)`` using ``a`` and ``b`` from
   ``data/pivot_params.json``.
3. Rescale the context levels to lie in ``[0, 1]`` and use them to
   interpolate the pivot weights radially.
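4. Generate a lattice ``E0`` by mapping each pixel's radial distance
   (normalised by ``L/2`` and clipped to ``[0, 1]``) to an interpolated
   pivot weight, optionally multiplied by a Gaussian radial decay
   ``exp(-decay_alpha * r**2)`` and by a gauge amplitude factor.
5. Smooth ``E0`` with a Gaussian filter whose standard deviation is
   ``ell`` (in lattice units).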
6. Compute the gradient magnitude of the smoothed envelope and
   normalise it by its spatial mean to obtain ``G_hat``.

The normalised gradient magnitude ``G_hat`` is the key quantity fed
into the Poisson and optical translators.
"""

from __future__ import annotations

import numpy as np
import os
from scipy.ndimage import gaussian_filter
from typing import Tuple

from .io_fphs import load_D_values, load_pivot_params


def _prepare_radial_interpolant(n_vals: np.ndarray, g_vals: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Prepare arrays for 1‑D radial interpolation.

    The context levels ``n_vals`` are rescaled to the range ``[0, 1]``.
    The corresponding pivot weights ``g_vals`` are returned in the same
    sorted order.  The caller can pass these arrays directly to
    ``numpy.interp``.

    Parameters
    ----------
    n_vals : ndarray
        Context indices (floats).  Must be one‑dimensional.
    g_vals : ndarray
        Pivot weights corresponding to each context level.

    Returns
    -------
    r_frac_sorted : ndarray
        Rescaled and sorted context levels in ``[0, 1]``.
    g_sorted : ndarray
        Pivot weights sorted according to the rescaled context levels.
    """
    # Sort by n to ensure monotonic interpolation
    sort_idx = np.argsort(n_vals)
    n_sorted = n_vals[sort_idx]
    g_sorted = g_vals[sort_idx]
    # Rescale n to [0, 1]
    n_min = float(n_sorted.min())
    n_max = float(n_sorted.max())
    # Avoid divide by zero if only a single level
    if n_max - n_min == 0:
        r_frac = np.zeros_like(n_sorted, dtype=float)
    else:
        r_frac = (n_sorted - n_min) / (n_max - n_min)
    return r_frac, g_sorted


def _radial_grid(L: int) -> np.ndarray:
    """Compute the radial coordinate (normalised) for each lattice point.

    The radial distance is computed from the centre of the lattice and
    normalised to lie in ``[0, 1]`` by dividing by ``(L/2)``.

    Parameters
    ----------
    L : int
        Lattice size (length of one dimension).

    Returns
    -------
    r_grid : ndarray
        ``L × L`` array of radial coordinates in ``[0, 1]``.
    """
    coords = np.arange(L, dtype=float) - (L - 1) / 2.0
    x, y = np.meshgrid(coords, coords, indexing="ij")
    r = np.sqrt(x**2 + y**2)
    r_norm = r / (L / 2.0)
    # Clip to [0, 1] so that interpolation outside the original range
    # yields constant values at the boundary
    r_norm = np.clip(r_norm, 0.0, 1.0)
    return r_norm


def build_radial_envelope(
    L: int,
    n_vals: np.ndarray,
    g_vals: np.ndarray,
    gauge_factor: float = 1.0,
    decay_alpha: float | None = None,
) -> np.ndarray:
    """Construct a radially symmetric envelope ``E0`` on a square lattice.

    Each site's normalised radial coordinate is its distance from the
    lattice centre divided by ``L/2``, clipped to ``[0, 1]``.  That
    coordinate is mapped to a pivot weight by linear interpolation over the
    context levels, which are rescaled to ``[0, 1]`` and treated as radial
    shells.

    Parameters
    ----------
    L : int
        Lattice size (number of sites along each edge).
    n_vals : ndarray
        Context indices from ``D_values.csv``.
    g_vals : ndarray
        Pivot weights computed from fractal dimensions and pivot params,
        aligned element-wise with ``n_vals``.
    gauge_factor : float, optional
        Multiplicative factor applied to the whole envelope to model
        gauge‑dependent amplitude variations.  Defaults to 1.0.
    decay_alpha : float or None, optional
        If given and strictly positive, the envelope is multiplied by the
        radial decay factor ``exp(-decay_alpha * r_norm**2)`` to localise
        the source.  A higher ``decay_alpha`` yields a more concentrated
        core.  ``None``, zero or negative values apply no decay.

    Returns
    -------
    E0 : ndarray
        ``L × L`` array containing the radially interpolated pivot weights.
        If requested, the decay factor is applied.  The gauge factor is
        always applied.
    """
    r_frac, g_sorted = _prepare_radial_interpolant(n_vals, g_vals)
    r_grid = _radial_grid(L)
    # np.interp holds values constant outside the range of r_frac.  Since
    # r_grid is clipped to [0, 1], sites beyond the outermost shell take the
    # boundary pivot weight.
    E0 = np.interp(r_grid, r_frac, g_sorted)
    # Design rationale (from the original author, not verified here):
    # without a decay, the interpolated weights extend to the lattice
    # boundary and give a slowly varying source.  The Gaussian decay acts
    # like a correlation length in the underlying kernel and is intended to
    # make the far field behave like 1/r.
    if decay_alpha is not None and decay_alpha > 0.0:
        # r_grid lies in [0, 1], so decay_alpha sets the e-folding in r_norm**2.
        decay = np.exp(-decay_alpha * (r_grid**2))
        E0 *= decay
    # Gauge factor lets different gauge groups (e.g. SU(2)/SU(3)) differ in
    # amplitude.
    E0 *= gauge_factor
    return E0


def smooth_and_gradient(E0: np.ndarray, ell: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Smooth a 2‑D envelope and compute its gradient components.

    A Gaussian filter with standard deviation ``ell`` (``mode="reflect"``)
    is applied to ``E0``.  For ``ell <= 0`` no smoothing is applied and a
    copy of ``E0`` is used instead.

    The gradient of the smoothed envelope is computed with
    ``numpy.gradient`` using unit spacing: central differences in the
    interior and one-sided differences at the edges.  Along any axis of
    length < 2 a finite difference is impossible, and the gradient is
    reported as zero.  The gradient magnitude can then be normalised by its
    spatial mean (see ``normalise_gradient_magnitude``).

    Parameters
    ----------
    E0 : ndarray
        Raw envelope; must be two-dimensional (typically ``L × L``).
        Integer or boolean input is promoted to float so that smoothing is
        not truncated to integers.
    ell : int
        Smoothing width (Gaussian sigma) in lattice units.  Values ``<= 0``
        disable smoothing.

    Returns
    -------
    E_smooth : ndarray
        Smoothed envelope of the same shape as ``E0``.
    grad_x : ndarray
        Gradient of ``E_smooth`` along the first axis.
    grad_y : ndarray
        Gradient of ``E_smooth`` along the second axis.

    Raises
    ------
    ValueError
        If ``E0`` is not two-dimensional.
    """
    field = np.asarray(E0)
    if field.ndim != 2:
        raise ValueError(f"E0 must be a 2-D array, got shape {field.shape}")
    # gaussian_filter writes into an array of the input dtype, so an integer
    # envelope would have its smoothed values truncated; promote first.
    if not np.issubdtype(field.dtype, np.inexact):
        field = field.astype(float)

    if ell <= 0:
        # No smoothing requested; still return an independent array.
        E_smooth = field.copy()
    else:
        E_smooth = gaussian_filter(field, sigma=ell, mode="reflect")

    # np.gradient needs at least two samples per axis; a degenerate axis has
    # no variation along it, so its derivative is taken to be zero.
    grads = [
        np.gradient(E_smooth, axis=axis) if size >= 2 else np.zeros(E_smooth.shape, dtype=float)
        for axis, size in enumerate(E_smooth.shape)
    ]
    grad_x, grad_y = grads
    return E_smooth, grad_x, grad_y


def normalise_gradient_magnitude(grad_x: np.ndarray, grad_y: np.ndarray) -> np.ndarray:
    """Return the pointwise gradient magnitude rescaled to unit spatial mean.

    The magnitude ``sqrt(grad_x**2 + grad_y**2)`` is formed at every site
    and divided by its mean over the whole array.  The result therefore
    describes the shape of the envelope's gradient independently of the
    envelope's overall amplitude.

    Parameters
    ----------
    grad_x : ndarray
        Gradient component along the first axis.
    grad_y : ndarray
        Gradient component along the second axis (same shape as ``grad_x``).

    Returns
    -------
    G_hat : ndarray
        Normalised gradient magnitude.  An all-zero array is returned when
        the magnitude's mean is zero.
    """
    magnitude = np.sqrt(grad_x**2 + grad_y**2)
    spatial_mean = magnitude.mean()
    # A flat field has zero gradient everywhere; dividing by its zero mean
    # would produce NaNs, so an all-zero map is returned instead.
    return np.zeros_like(magnitude) if spatial_mean == 0 else magnitude / spatial_mean


def build_envelope(
    L: int,
    ell: int,
    data_dir: str,
    gauge_factor: float = 1.0,
    decay_alpha: float | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Construct a smoothed, normalised kernel envelope for a given lattice.

    This convenience wrapper loads fractal anchors and pivot params from
    ``data_dir``, builds a radially symmetric envelope ``E0`` of size
    ``L × L``, smoothes it with a Gaussian of width ``ell``, computes
    its gradient and normalises the magnitude.

    Parameters
    ----------
    L : int
        Lattice size (must be positive).
    ell : int
        Smoothing width (Gaussian sigma) in lattice units (must be
        non‑negative).  When ``ell=0`` no smoothing is applied.
    data_dir : str
        Directory containing ``D_values.csv`` and ``pivot_params.json``.
    gauge_factor : float, optional
        Amplitude multiplier to distinguish different gauge groups.  A
        value of 1.0 uses the raw pivot weights; other values scale
        the entire envelope.

    Additional Parameters
    --------------------
    decay_alpha : float or None, optional
        Radial decay factor used to localise the envelope (see
        ``build_radial_envelope``).  If ``None`` (default), no decay is
        applied.

    Returns
    -------
    E0 : ndarray
        Raw radial envelope (``L × L``).
    E_smooth : ndarray
        Smoothed envelope of the same shape.
    grad_x : ndarray
        Gradient of the smoothed envelope along the first axis.
    grad_y : ndarray
        Gradient of the smoothed envelope along the second axis.
    G_hat : ndarray
        Normalised gradient magnitude (shape ``L × L``).
    """
    # Load data
    n_vals, D_vals = load_D_values(os.path.join(data_dir, "D_values.csv"))
    a, b = load_pivot_params(os.path.join(data_dir, "pivot_params.json"))
    # Compute pivot weights g(D) = a D + b
    g_vals = a * D_vals + b
    # Build raw radial envelope with optional decay
    E0 = build_radial_envelope(
        L,
        n_vals,
        g_vals,
        gauge_factor=gauge_factor,
        decay_alpha=decay_alpha,
    )
    # Smooth and compute gradients
    E_smooth, grad_x, grad_y = smooth_and_gradient(E0, ell)
    # Normalise gradient magnitude
    G_hat = normalise_gradient_magnitude(grad_x, grad_y)
    return E0, E_smooth, grad_x, grad_y, G_hat
